{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fe548d4c-0ae6-4490-85e9-9a08b324e523",
   "metadata": {},
   "source": [
    "# K Nearest Neighbors"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1dffa51f",
   "metadata": {
    "height": 98
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.neighbors import NearestNeighbors\n",
    "import time\n",
    "np.random.seed(42)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c6507163-f256-4969-b39a-42e4fe87e16e",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Sample 20 random points in the 2-D unit square to act as toy embeddings\n",
    "n_points, n_dims = 20, 2\n",
    "X = np.random.rand(n_points, n_dims)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d57be60-f38c-40ec-9a79-6abe351ff8e1",
   "metadata": {
    "height": 165
   },
   "outputs": [],
   "source": [
    "# Scatter-plot every embedding and label each point with its row index\n",
    "fig, ax = plt.subplots()\n",
    "xs, ys = X[:, 0], X[:, 1]\n",
    "ax.scatter(xs, ys, label='Embeddings')\n",
    "ax.legend()\n",
    "\n",
    "for point_id in range(len(X)):\n",
    "    ax.annotate(point_id, (xs[point_id], ys[point_id]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7607bd1-41e2-4f8b-9c25-9deba45c1bc4",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "# Number of neighbours to retrieve per query\n",
    "k = 4\n",
    "\n",
    "# Exhaustive (brute-force) index over X using Euclidean distance\n",
    "neigh = NearestNeighbors(\n",
    "    metric='euclidean',\n",
    "    algorithm='brute',\n",
    "    n_neighbors=k,\n",
    ")\n",
    "neigh.fit(X)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0dfb6e89-cfb6-4ef3-a5b9-b8ba65cfe7b1",
   "metadata": {
    "height": 182
   },
   "outputs": [],
   "source": [
    "# Overlay the query point (red) on the labelled embedding scatter plot\n",
    "fig, ax = plt.subplots()\n",
    "xs, ys = X[:, 0], X[:, 1]\n",
    "ax.scatter(xs, ys)\n",
    "query_x, query_y = 0.45, 0.2\n",
    "ax.scatter(query_x, query_y, c='red', label='Query')\n",
    "ax.legend()\n",
    "\n",
    "for point_id in range(len(X)):\n",
    "    ax.annotate(point_id, (xs[point_id], ys[point_id]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cd20c13b-c830-47fc-9bde-824b81607ebc",
   "metadata": {
    "height": 63
   },
   "outputs": [],
   "source": [
    "# Ask the fitted index for the k points closest to the query point\n",
    "query_point = [[0.45, 0.2]]\n",
    "neighbours = neigh.kneighbors(query_point, n_neighbors=k, return_distance=True)\n",
    "print(neighbours)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "31e7e76f-f26e-46a5-9778-8e1aa46e18b5",
   "metadata": {
    "height": 131
   },
   "outputs": [],
   "source": [
    "# Time a single brute-force query.\n",
    "# perf_counter() is a monotonic, high-resolution clock intended for measuring\n",
    "# short intervals; time.time() can be coarse and can jump if the system clock changes.\n",
    "t0 = time.perf_counter()\n",
    "neighbours = neigh.kneighbors([[0.45,0.2]], k, return_distance=True)\n",
    "t1 = time.perf_counter()\n",
    "\n",
    "query_time = t1-t0\n",
    "print(f\"Runtime: {query_time: .4f} seconds\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4827c91-797a-469d-9679-1dbba2e864d1",
   "metadata": {
    "height": 335
   },
   "outputs": [],
   "source": [
    "def speed_test(count, k=4):\n",
    "    \"\"\"Time one brute-force k-nearest-neighbour query over `count` random 2-D points.\n",
    "\n",
    "    Args:\n",
    "        count: number of random points to generate and index.\n",
    "        k: number of neighbours to retrieve; defaults to 4.\n",
    "\n",
    "    Returns:\n",
    "        Elapsed time of the query in seconds (fitting the index is not timed).\n",
    "    \"\"\"\n",
    "    # generate random objects\n",
    "    data = np.random.rand(count, 2)\n",
    "\n",
    "    # prepare brute force index\n",
    "    neigh = NearestNeighbors(n_neighbors=k, algorithm='brute', metric='euclidean')\n",
    "    neigh.fit(data)\n",
    "\n",
    "    # measure time for a brute force query; perf_counter() is the monotonic,\n",
    "    # high-resolution clock meant for timing short intervals\n",
    "    t0 = time.perf_counter()\n",
    "    neigh.kneighbors([[0.45, 0.2]], k, return_distance=True)\n",
    "    t1 = time.perf_counter()\n",
    "\n",
    "    total_time = t1 - t0\n",
    "    print (f\"Runtime: {total_time: .4f}\")\n",
    "\n",
    "    return total_time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cae5ed6a-9fb8-4fc3-85c6-b9eff573ad5a",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": [
    "# Baseline timing: one brute-force query against 20,000 random points\n",
    "time20k = speed_test(count=20_000)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1f114314-b875-40b0-be81-bae6b1762765",
   "metadata": {
    "height": 97
   },
   "outputs": [],
   "source": [
    "# Brute force examples\n",
    "# Each run indexes 10x more points than the previous one; the largest generates\n",
    "# 200,000,000 x 2 random values, so expect it to be slow and memory hungry.\n",
    "time200k = speed_test(count=200_000)\n",
    "time2m = speed_test(count=2_000_000)\n",
    "time20m = speed_test(count=20_000_000)\n",
    "time200m = speed_test(count=200_000_000)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1760c511-c422-42dc-a41c-b125623a8619",
   "metadata": {},
   "source": [
    "## Brute force kNN implemented by hand on `768`-dimensional embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d140b5d1-e110-4a42-9e3c-00beab03694d",
   "metadata": {
    "height": 165
   },
   "outputs": [],
   "source": [
    "documents = 1000\n",
    "dimensions = 768\n",
    "\n",
    "# `documents` rows, each a `dimensions`-dimensional embedding\n",
    "embeddings = np.random.randn(documents, dimensions)\n",
    "# L2 normalize the rows, as is common\n",
    "embeddings = embeddings / np.sqrt((embeddings**2).sum(1, keepdims=True))\n",
    "\n",
    "# the query vector; sized from `dimensions` so it always matches the embeddings\n",
    "query = np.random.randn(dimensions)\n",
    "query = query / np.sqrt((query**2).sum()) # normalize query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e8cfca36-3b4e-4103-bae1-98939b48c78c",
   "metadata": {
    "height": 267
   },
   "outputs": [],
   "source": [
    "# kNN\n",
    "# perf_counter() is the monotonic, high-resolution clock meant for short intervals\n",
    "t0 = time.perf_counter()\n",
    "# Calculate Dot Product between the query and all data items\n",
    "similarities = embeddings.dot(query)\n",
    "# Sort results (negated so the highest similarity comes first)\n",
    "sorted_ix = np.argsort(-similarities)\n",
    "t1 = time.perf_counter()\n",
    "\n",
    "total = t1-t0\n",
    "print(f\"Runtime for dim={dimensions}, documents_n={documents}: {np.round(total,3)} seconds\")\n",
    "\n",
    "print(\"Top 5 results:\")\n",
    "# Use a dedicated loop name: reusing `k` here would overwrite the neighbour\n",
    "# count `k` that the NearestNeighbors cells rely on when they are re-run.\n",
    "for doc_ix in sorted_ix[:5]:\n",
    "    print(f\"Point: {doc_ix}, Similarity: {similarities[doc_ix]}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "51493604-ca06-4129-a026-007b6ecbb478",
   "metadata": {
    "height": 250
   },
   "outputs": [],
   "source": [
    "n_runs = [1_000, 10_000, 100_000, 500_000]\n",
    "\n",
    "# Benchmark-specific names leave the normalized `embeddings` / `query` globals\n",
    "# untouched, so the kNN cell can still be re-run after this loop.\n",
    "for n_docs in n_runs:\n",
    "    bench_embeddings = np.random.randn(n_docs, dimensions) # `dimensions`-dimensional embeddings\n",
    "    bench_query = np.random.randn(dimensions) # the query vector\n",
    "\n",
    "    # perf_counter() is the monotonic, high-resolution clock meant for short intervals\n",
    "    t0 = time.perf_counter()\n",
    "    bench_similarities = bench_embeddings.dot(bench_query)\n",
    "    bench_sorted_ix = np.argsort(-bench_similarities)\n",
    "    t1 = time.perf_counter()\n",
    "\n",
    "    # `total` keeps the timing of the most recent (finally the largest) run\n",
    "    total = t1-t0\n",
    "    print(f\"Runtime for 1 query with dim={dimensions}, documents_n={n_docs}: {np.round(total,3)} seconds\")\n",
    "\n",
    "    # release this run's array (about 3 GB of float64 at 500,000 x 768) before the next allocation\n",
    "    del bench_embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9748863e-5640-42c4-8223-3d0897867c58",
   "metadata": {
    "height": 46
   },
   "outputs": [],
   "source": [
    "# Extrapolate the most recent single-query time (`total`) to 1,000 queries, in minutes\n",
    "minutes_for_1k_queries = total * 1_000/60\n",
    "print (f\"To run 1,000 queries: {minutes_for_1k_queries: .2f} minutes\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "deeba5ab-f767-44ac-8c37-47a46f26f1f4",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c1d6eb5a-7eef-49db-94c2-4f81085aa319",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef1d12f2-f4d5-4bea-9268-f3d38de1d01d",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f67c4fb3-56a4-4b5e-b3e6-e8517a9958b8",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e49f3e52-32ad-46d8-9222-e17db1822ead",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "27b08d45-4c19-4123-b0f5-948785810347",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4fe5c54c-f7ba-43dd-a9e2-006f180d19b8",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d2f19487-55fe-4fd8-8895-d0fbd9d408fa",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e14c073d-f51d-4229-bb18-bf2ca33b6a45",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "136a9eaa-f548-419b-a280-f75fc0035c4b",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "daee8bd2-2a65-40c1-a644-abb59f8ea5a3",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e32b66c0-3522-4aea-ab08-f29ef5138169",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d069b75f-6549-475b-bbb4-40c0a93e01b4",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "dbd4335f-f2d1-4725-a4c0-6a12f79530ee",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1c6a0e14-34d6-4c6b-9ee6-8f3e72547dae",
   "metadata": {
    "height": 29
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
